{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Set parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# --- Input: the trained Keras model ---\n",
    "# Folder that holds the Keras model file.\n",
    "input_fld = 'input_fld_path'\n",
    "# Model file name; it is joined onto input_fld in the next code cell.\n",
    "weight_file = 'kerasmodel_file_name located inside input_fld'\n",
    "\n",
    "# --- Output: the frozen TensorFlow graph ---\n",
    "# File name used when the frozen (constant) graph is written.\n",
    "output_graph_name = 'constant_graph_weights.pb'\n",
    "# Output node i is named with this prefix followed by str(i).\n",
    "prefix_output_node_names_of_final_network = 'output_node'\n",
    "# How many named output nodes to create.\n",
    "num_output = 1\n",
    "# When True, the graph definition is additionally written as text.\n",
    "write_graph_def_ascii_flag = True"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Initialize"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import os.path as osp\n",
    "\n",
    "import tensorflow as tf\n",
    "from keras import backend as K\n",
    "from keras.models import load_model\n",
    "\n",
    "# The converted model goes into a 'tensorflow_model' sub-folder of input_fld.\n",
    "# osp.join inserts the path separator when input_fld has no trailing slash;\n",
    "# plain string concatenation would instead produce a sibling folder such as\n",
    "# 'input_fld_pathtensorflow_model/'.\n",
    "output_fld = osp.join(input_fld, 'tensorflow_model/')\n",
    "if not osp.isdir(output_fld):\n",
    "    # os.mkdir (not makedirs) on purpose: a mistyped input_fld should fail\n",
    "    # here rather than be created silently.\n",
    "    os.mkdir(output_fld)\n",
    "# Full path of the Keras model file to convert.\n",
    "weight_file_path = osp.join(input_fld, weight_file)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load Keras model and rename outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "output nodes names are:  ['output_node0']\n"
     ]
    }
   ],
   "source": [
    "# Select inference mode (learning phase 0) BEFORE the model is loaded, so the\n",
    "# graph is built for inference from the start. Setting it only after\n",
    "# load_model() can leave learning-phase-dependent ops (e.g. Dropout,\n",
    "# BatchNormalization) tied to the phase placeholder in the exported graph.\n",
    "K.set_learning_phase(0)\n",
    "net_model = load_model(weight_file_path)\n",
    "\n",
    "# Attach an identity op with a predictable name to each model output.\n",
    "# net_model.outputs is always a list of tensors, whereas net_model.output is a\n",
    "# bare tensor for single-output models, so net_model.output[i] would slice the\n",
    "# tensor's first axis instead of selecting the i-th output.\n",
    "pred_node_names = [prefix_output_node_names_of_final_network + str(i)\n",
    "                   for i in range(num_output)]\n",
    "pred = [tf.identity(net_model.outputs[i], name=pred_node_names[i])\n",
    "        for i in range(num_output)]\n",
    "print('output nodes names are: ', pred_node_names)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### [Optional] Write graph definition in ASCII"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "saved the graph definition in ascii format at:  /home/amir/deep-batch/trained/KerasEchoQualityMultiStream/snapshots/freeze_uniformValid_bn/s/tensorflow_model/only_the_graph_def.pb.ascii\n"
     ]
    }
   ],
   "source": [
    "# TensorFlow session used by the Keras backend; the freezing cell reuses `sess`.\n",
    "sess = K.get_session()\n",
    "\n",
    "# Optional: write the graph definition as text into output_fld.\n",
    "if write_graph_def_ascii_flag:\n",
    "    ascii_graph_name = 'only_the_graph_def.pb.ascii'\n",
    "    graph_def = sess.graph.as_graph_def()\n",
    "    tf.train.write_graph(graph_def, output_fld, ascii_graph_name, as_text=True)\n",
    "    ascii_graph_path = osp.join(output_fld, ascii_graph_name)\n",
    "    print('saved the graph definition in ascii format at: ', ascii_graph_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Convert variables to constants and save"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Froze 27 variables.\n",
      "Converted 27 variables to const ops.\n",
      "saved the constant graph (ready for inference) at:  /home/amir/deep-batch/trained/KerasEchoQualityMultiStream/snapshots/freeze_uniformValid_bn/s/tensorflow_model/constant_graph_weights.pb\n"
     ]
    }
   ],
   "source": [
    "# Freeze the graph: variables needed by the named output nodes are replaced by\n",
    "# constants holding their current values in `sess`.\n",
    "# The public tf.graph_util / tf.train.write_graph entry points are used instead\n",
    "# of imports from the private tensorflow.python package; `tf` is already\n",
    "# imported in the initialize cell, so no extra imports are needed here.\n",
    "constant_graph = tf.graph_util.convert_variables_to_constants(\n",
    "    sess, sess.graph.as_graph_def(), pred_node_names)\n",
    "# as_text=False writes the frozen GraphDef as a binary protobuf file.\n",
    "tf.train.write_graph(constant_graph, output_fld, output_graph_name, as_text=False)\n",
    "print('saved the constant graph (ready for inference) at: ', osp.join(output_fld, output_graph_name))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
